-
Notifications
You must be signed in to change notification settings - Fork 9.6k
Implemented a Siamese Network Example #1003
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implemented a Siamese Network Example #1003
Conversation
Hi @sudomaze! Thank you for your pull request and welcome to our community. Action RequiredIn order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at [email protected]. Thanks! |
✅ Deploy Preview for pytorch-examples-preview ready!
To edit notification comments on pull requests, go to your Netlify site settings. |
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks! |
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks! |
Can you please setup a test here as well? https://github.com/pytorch/examples/blob/main/run_python_examples.sh |
pip install -r requirements.txt | ||
python main.py | ||
# CUDA_VISIBLE_DEVICES=2 python main.py # to specify GPU id to ex. 2 | ||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you also reference the paper you're using a baseline for your implementation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have referenced FaceNet
, the closest implementation to the example's implementation. This implementation varies from FaceNet as we use the ResNet-18
model as our feature extractor. In addition, we aren't using TripletLoss
as the MNIST dataset is simple, so BCELoss
can do the trick.
siamese_network/main.py
Outdated
output1 = self.forward_once(input1) | ||
output2 = self.forward_once(input2) | ||
|
||
# concatnate both images' features |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
theres's a few simple typos around, a quick spell checker would help
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing it out! I have fixed it! 😄
siamese_network/main.py
Outdated
|
||
def main(): | ||
# Training settings | ||
parser = argparse.ArgumentParser(description='PyTorch MNIST Example') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update to Siamese network
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I missed it. Thanks!
siamese_network/main.py
Outdated
self.dataset = datasets.MNIST(root, train=train, download=download) | ||
|
||
# get targets (labels) and data (images) | ||
self.targets = copy.deepcopy(self.dataset.targets) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the goal of the group sets section? Could you add some comments explaining why the deep copies are needed and what this is doing exactly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have removed unwanted copying of objects and included detailed comments to explain the process.
|
||
test_loss /= len(test_loader.dataset) | ||
|
||
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add comment around how much loss is expected to be if default settings are used
def __len__(self): | ||
return self.data.shape[0] | ||
|
||
def __getitem__(self, index): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add more comments in this section
- Included `siamese_network` setup to `run_python_examples.sh` - In `siamese_network/main.py`, included more explanation and detailed comments per @msaroufim's feedback.
`sphinx-serve -d build` doesn't work as `-d` flag doesn't exist. The correct flag is `-b` per `sphinx-serve` help page. Hence, the edit was meant to reflect the correct command `sphinx-serve -b build`.
@@ -0,0 +1,7 @@ | |||
# Siamese Network Example |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK last thing can you please quickly explain what a Siamese network is here, the high-level architecture and what users should expect to input and get as an output for this example. Not a tutorial.
Also, post some proof in a screenshot here on the PR that this thing works
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
…k_example Implemented a Siamese Network Example
I have followed
mnist/main.py
setup and modified it a bit to include the definition ofSiamese Network
and custom configuration ofMNIST
set but for theObject Matching
task.This PR was created per #645